from model_wage import micro,macro,params
from model_wage import distributions as dist
import numpy as np
from scipy.optimize import minimize
import csv
from copy import deepcopy
import nlopt  as nl
import os
# Directory two levels above this file. The msm class uses it to locate
# model/params/sim_gdp_us.csv, through which simulated US output is shared
# with the runs for other countries.
module_dir = os.path.dirname(os.path.dirname(__file__))

class point:
    """Record for one estimable parameter.

    Fields: ``name``, current ``value``, ``ifree`` (whether the optimiser
    may move it), and lower/upper bounds ``low``/``up``. Every field starts
    out as None.
    """

    def __init__(self):
        for field in ('name', 'value', 'ifree', 'low', 'up'):
            setattr(self, field, None)

class moment:
    """Record for one targeted moment.

    Fields: ``name``, the observed value ``data``, its standard error
    ``se``, and the simulated counterpart ``sim``. Every field starts out
    as None.
    """

    def __init__(self):
        for field in ('name', 'data', 'se', 'sim'):
            setattr(self, field, None)

class initpars:
    """Container for the model parameters that can be estimated.

    Each parameter is stored as a ``point`` with its current value, a
    free/fixed flag (``ifree``) and default bounds of +/-10% around the
    initial value. Values are read from, and written back to, ``flex``.
    The pseudo-parameters ``delta_h1``/``delta_h2`` map to ``flex.delta[0]``
    and ``flex.delta[1]``. Every other name maps to the ``flex`` attribute
    of the same name.
    """

    # Pseudo-parameter name -> index into flex.delta.
    _DELTA_INDEX = {'delta_h1': 0, 'delta_h2': 1}

    def __init__(self, flex):
        self.flex = flex
        self.names = ['sigma', 'beta', 'phi', 'psi', 'delta_h1',
                      'delta_h2', 'eta', 'tfp', 'price', 'risk']
        self.pars = []
        for name in self.names:
            this = point()
            this.name = name
            this.value = self._get_flex_value(name)
            this.ifree = False
            this.low, this.up = self._default_bounds(this.value)
            self.pars.append(this)
        return

    @staticmethod
    def _default_bounds(value):
        """Return (low, up) at -/+10% of ``value``, ordered so low <= up.

        Plain ``value*0.9`` / ``value*1.1`` would swap the bounds for a
        negative value.
        """
        a = value * 0.9
        b = value * 1.1
        return (a, b) if a <= b else (b, a)

    def _get_flex_value(self, name):
        """Read parameter ``name`` from ``flex`` (delta_h* -> flex.delta[i])."""
        if name in self._DELTA_INDEX:
            return self.flex.delta[self._DELTA_INDEX[name]]
        return getattr(self.flex, name)

    def _set_flex_value(self, name, value):
        """Write parameter ``name`` into ``flex`` (delta_h* -> flex.delta[i])."""
        if name in self._DELTA_INDEX:
            self.flex.delta[self._DELTA_INDEX[name]] = value
        else:
            setattr(self.flex, name, value)

    def _set_status(self, name, value, ifree):
        """Set the free flag (and optionally the value) of parameter ``name``.

        Unknown names are silently ignored. ``is not None`` is used so that
        array-valued arguments do not trigger an elementwise comparison.
        """
        for p in self.pars:
            if p.name == name:
                if value is not None:
                    p.value = value
                p.ifree = ifree

    def fix(self, name, value=None):
        """Hold parameter ``name`` fixed, optionally resetting its value."""
        self._set_status(name, value, False)
        return

    def free(self, name, value=None):
        """Let the optimiser move parameter ``name``, optionally resetting it."""
        self._set_status(name, value, True)
        return

    def extract_theta(self):
        """Return the values of the free parameters, in declaration order."""
        return [p.value for p in self.pars if p.ifree]

    def put_theta(self, theta):
        """Write ``theta`` back into the free parameters, in declaration order.

        Extra trailing entries of ``theta`` are ignored. A ``theta`` shorter
        than the number of free parameters raises IndexError.
        """
        free_pars = [p for p in self.pars if p.ifree]
        for j, p in enumerate(free_pars):
            p.value = theta[j]
        return

    def extract_low(self):
        """Return the lower bounds of the free parameters."""
        return [p.low for p in self.pars if p.ifree]

    def extract_up(self):
        """Return the upper bounds of the free parameters."""
        return [p.up for p in self.pars if p.ifree]

    def set_flex(self):
        """Copy every parameter value (free or fixed) into ``flex``."""
        for p in self.pars:
            self._set_flex_value(p.name, p.value)
        return

    def print(self):
        """Print each parameter's name, rounded value and free flag."""
        print('current parameter status : ')
        for p in self.pars:
            print(p.name, np.round(p.value, 4), p.ifree)
        return


class msm:
    """Method-of-simulated-moments (MSM) estimator for one country.

    Workflow visible here:
    1. Construct the estimator.
    2. Call ``set_moments`` with the data targets.
    3. Call ``estimate`` to minimise the weighted moment distance with
       NLopt's LN_NEWUOA.
    4. Call ``covar`` to store standard errors of the free parameters in
       ``self.se``.

    Parameters
    ----------
    country : str
        Country code. 'us' and 'nl' have dedicated settings; any other code
        uses the defaults.
    initpar : initpars or None
        Parameter container to estimate. When None, one is built from
        ``params.flexpars(country=country)``.
    verbose : bool
        Print an extra info line at construction.
    nprocs : int
        Forwarded to ``params.settings``.
    ge : bool
        Forwarded as ``rent`` to ``macro.equilibrium``. NOTE(review):
        presumably toggles endogenous rent determination; confirm.
    """

    # Initial NEWUOA step per free parameter. 'us' and 'nl' override a few
    # entries of the default table. (The original if/if/else chain let the
    # default table overwrite the 'us' steps; a dict lookup cannot.)
    _DEFAULT_STEPS = {'sigma': 0.25, 'beta': 0.01, 'phi': 0.1, 'psi': 0.1,
                      'delta_h1': 0.5, 'delta_h2': 0.5, 'eta': 0.01,
                      'tfp': 0.1, 'price': 0.25, 'risk': 0.25}
    _STEP_SIZES = {
        'us': dict(_DEFAULT_STEPS, delta_h1=0.01, delta_h2=0.01),
        'nl': dict(_DEFAULT_STEPS, delta_h1=0.1, delta_h2=0.1, price=0.1),
    }

    # Moment name -> index into report.gradient.
    _GRADIENT_INDEX = {'grad2': 0, 'grad3': 1, 'grad4': 2}

    def __init__(self, country='us', initpar=None, verbose=False, nprocs=40, ge=True):
        self.country = country
        self.verbose = verbose
        self.ge = ge
        self.moments = []
        self.nmoms = 0
        if verbose: print('* Info on parameters initialized')
        if initpar is not None:
            self.initpar = initpar
            self.flex = self.initpar.flex
        else:
            self.flex = params.flexpars(country=country)
            self.initpar = initpars(self.flex)
        self.npar = len(self.initpar.pars)
        self.nfreepar = [p.ifree for p in self.initpar.pars].count(True)
        self.parnames = [p.name for p in self.initpar.pars]
        print('-- ')
        print('number of parameters: ', self.npar)
        print('number of free parameters: ', self.nfreepar)
        self.initpar.print()
        maxk = 190.0 if self.country == 'us' else 150.0
        self.op = params.settings(nprocs=nprocs, nk=30, curv=0.5, maxk=maxk)
        self.inc = params.incprocess(country=self.country)
        self.inc.tauchen(ne=self.op.ne, m=2.5)
        self.aux = params.auxpars(country=self.country, gamma=0.8)
        return

    def set_moments(self, moms):
        """Register the targeted moments.

        ``moms`` is indexed by moment name and read with ``.loc[name,
        'mean']`` (data value) and ``.loc[name, 'sd']`` (standard error),
        e.g. a pandas DataFrame.
        """
        self.moments = []
        self.nmoms = len(moms)
        print('- using these moments:')
        for n in moms.index:
            this = moment()
            this.name = n
            this.data = moms.loc[n, 'mean']
            this.se = moms.loc[n, 'sd']
            print(n, this.data, this.se)
            self.moments.append(this)
        return

    # ------------------------------------------------------------------ helpers

    def _us_gdp_path(self):
        """Path of the file through which simulated US output is shared."""
        return module_dir + '/model/params/sim_gdp_us.csv'

    def _read_us_gdp(self):
        """Read the simulated US output written by a previous 'us' run."""
        with open(self._us_gdp_path(), 'r') as f:
            return float(f.readline())

    def _initial_tax(self):
        """Initial tax guess derived from the 'mshare' data moment.

        Raises ValueError when no 'mshare' moment has been set (the original
        code hit a NameError instead). If several are present, the last one
        wins.
        """
        mshare = None
        for m in self.moments:
            if m.name == 'mshare':
                mshare = m.data
        if mshare is None:
            raise ValueError("an 'mshare' moment is required to compute the initial tax")
        aux = params.auxpars(country=self.country)
        return (1.0 - aux.copay) * mshare / (1.0 - aux.alpha)

    def _initial_step(self):
        """NEWUOA initial step for each free parameter, in declaration order.

        Names missing from the step table get a step of 0.0.
        """
        steps = self._STEP_SIZES.get(self.country, self._DEFAULT_STEPS)
        return np.array([steps.get(p.name, 0.0)
                         for p in self.initpar.pars if p.ifree], dtype=float)

    def _simulate_moments(self, aggs, report):
        """Set ``m.sim`` for every targeted moment from equilibrium output.

        ``aggs`` supplies C, M, K, Y. ``report`` supplies pTransBad,
        pTransGood and gradient. For countries other than 'us', the 'tfp'
        moment is output relative to the stored US output. Raises ValueError
        for an unknown moment name, instead of leaving ``sim`` as None.
        """
        price = self.flex.price
        tfp_us = None
        for m in self.moments:
            name = m.name
            if name == 'cshare':
                m.sim = (aggs.C + aggs.M * price) / aggs.Y
            elif name == 'mshare':
                m.sim = aggs.M / aggs.Y * price
            elif name == 'kshare':
                m.sim = aggs.K / aggs.Y
            elif name == 'trans_frombad':
                m.sim = report.pTransBad
            elif name == 'trans_fromgood':
                m.sim = report.pTransGood
            elif name in self._GRADIENT_INDEX:
                m.sim = report.gradient[self._GRADIENT_INDEX[name]]
            elif name == 'tfp':
                if self.country == 'us':
                    m.sim = 1.0
                else:
                    if tfp_us is None:
                        tfp_us = self._read_us_gdp()
                    m.sim = aggs.Y / tfp_us
            else:
                raise ValueError('unknown moment: {}'.format(name))

    def _release_solution(self):
        """Delete dp.optc, dp.optm and dp.value of the last equilibrium.

        NOTE(review): presumably done to save memory between evaluations;
        eq.solve() is expected to rebuild them. Confirm.
        """
        dp = self.eq.stats.dp
        del dp.optc
        del dp.optm
        del dp.value

    # --------------------------------------------------------------- objective

    def criterion(self, theta, grad):
        """MSM objective: sum of squared standardized moment deviations.

        Solves the model from scratch at ``theta`` (the free parameters),
        stores the equilibrium in ``self.eq`` and returns
        sum(((data - sim) / se) ** 2). ``grad`` is part of the NLopt
        objective signature and is not used. For 'us', the simulated output
        Y is written to ``sim_gdp_us.csv`` for later use by other countries.
        """
        if getattr(self, 'initax', None) is None:
            # Allow evaluating the objective without going through estimate().
            self.initax = self._initial_tax()
        self.initpar.put_theta(theta)
        self.initpar.set_flex()
        csumers = micro.bellman(options=self.op, flex=self.initpar.flex, inc=self.inc, aux=self.aux, rent=3e-2)
        stats = dist.stationary(dp=csumers, nk=100)
        self.eq = macro.equilibrium(stats=stats, initax=self.initax, inirent=1.5e-2, rent=self.ge, taxes=False)
        self.eq.solve()
        aggs = self.eq.aggregates()
        report = self.eq.healthreport()
        if self.country == 'us':
            with open(self._us_gdp_path(), 'w') as f:
                f.write('{}'.format(aggs.Y))
        self._simulate_moments(aggs, report)
        distance = sum((((m.data - m.sim) / m.se) ** 2 for m in self.moments), 0.0)
        print('f = ', distance, ', pars = ', theta)
        print('- current state of moments (data, sim):')
        for m in self.moments:
            print(m.name, m.data, m.sim)
        self._release_solution()
        return distance

    def criterion_moms(self, theta):
        """Return the vector of moment deviations (data - sim) at ``theta``.

        Re-solves the existing equilibrium ``self.eq``, so ``criterion`` must
        have been called before.
        """
        self.initpar.put_theta(theta)
        self.initpar.set_flex()
        self.eq.stats.flex = self.flex
        self.eq.stats.dp.flex = self.flex
        self.eq.solve()
        aggs = self.eq.aggregates()
        report = self.eq.healthreport()
        self._simulate_moments(aggs, report)
        moms = np.array([m.data - m.sim for m in self.moments])
        self._release_solution()
        return moms

    # -------------------------------------------------------------- estimation

    def estimate(self, maxeval=10000):
        """Minimise ``criterion`` over the free parameters with LN_NEWUOA.

        Stores ``opt_theta`` and ``opt_distance``. On failure these are the
        starting theta and NaN. Either way, ``opt_theta`` is then written
        back into ``initpar``/``flex``.
        """
        self.initax = self._initial_tax()
        theta = self.initpar.extract_theta()
        # Size from theta itself so parameters freed after __init__ are counted.
        n = len(theta)
        dx = self._initial_step()

        opt = nl.opt('LN_NEWUOA', n)
        opt.set_min_objective(self.criterion)
        opt.set_initial_step(dx)
        opt.set_maxeval(maxeval)
        opt.set_xtol_abs(1e-4)
        xopt = opt.optimize(theta)
        if opt.last_optimize_result() > 0:
            self.opt_theta = xopt
            self.opt_distance = opt.last_optimum_value()
            print('estimation did converge, now computing standard errors...')
        else:
            self.opt_theta = theta
            self.opt_distance = np.nan
            print('estimation did not converge, returns flag ', opt.last_optimize_result())

        self.initpar.put_theta(self.opt_theta)
        self.initpar.set_flex()
        return

    def covar(self):
        """Compute MSM standard errors of the free parameters into ``self.se``.

        Uses forward differences (step 1e-2 for every parameter) for the
        Jacobian G of the moment deviations, a diagonal weight W = 1/se^2,
        and Cov = (G' W G)^-1. The parameters in ``initpar``/``flex`` are
        restored to their base values afterwards. ``self.eq`` still holds the
        last perturbed solution.
        """
        self.eq.initax = self._initial_tax()
        thetas = list(self.initpar.extract_theta())
        n = len(thetas)
        eps = 1e-2 * np.ones(n)
        nmoms = len(self.moments)
        G = np.zeros((nmoms, n))
        mbase = self.criterion_moms(thetas)
        # compute G (matrix of derivatives)
        for k in range(n):
            thetas_up = list(thetas)
            thetas_up[k] = thetas_up[k] + eps[k]
            mup = self.criterion_moms(thetas_up)
            G[:, k] = (mup - mbase) / eps[k]
        # diagonal weight matrix
        W = np.diag([1 / (m.se ** 2) for m in self.moments])
        Cov = np.linalg.inv(G.transpose() @ W @ G)
        self.se = np.sqrt(np.diag(Cov))
        # Undo the last perturbation so flex holds the estimated parameters.
        self.initpar.put_theta(thetas)
        self.initpar.set_flex()
        return

